#include <mpi.h>
#include <assert.h>

#include "coul_msm.h"
#include "comm.h"
#include "va_arg_vscode_fix.h"
#include "memory_cpp.hpp"
/*
 * Set up the Cartesian process decomposition for the global box `gbox`.
 *
 * Picks an nx * ny * nz factorisation of the communicator size that minimises
 * the summed face areas of one rank's subdomain (presumably to keep halo
 * exchange small), places this rank on that grid with z varying fastest,
 * computes the rank's local box, and records the periodic neighbour ranks one
 * step forward (next) and backward (prev) along each axis.
 *
 * mpp  - partition descriptor; comm, nproc, pid, gbox, dim, loc, llen, lbox,
 *        next and prev are written.
 * comm - communicator over all participating ranks.
 * gbox - global box; copied into mpp->gbox, not modified.
 */
void comm_init(mpp_t *mpp, MPI_Comm comm, box<real> * gbox) {
  mpp->comm = comm;
  int pid, nproc;
  MPI_Comm_size(comm, &nproc);
  MPI_Comm_rank(comm, &pid);
  mpp->nproc = nproc;
  mpp->pid = pid;

  boxcpy(mpp->gbox, *gbox);
  vec<real> glen;
  vecsubv(glen, gbox->hi, gbox->lo);
  // Twice the full-box face sum: every candidate split of a box with non-zero
  // face area is below this, so the first candidate is always accepted.
  real best_area = (glen.x * glen.y + glen.y * glen.z + glen.z * glen.x) * 2;
  int nx = 0, ny = 0, nz = 0;
  for (int try_x = 1; try_x <= nproc; try_x++) {
    if (nproc % try_x != 0)
      continue;
    int nyz = nproc / try_x;
    for (int try_y = 1; try_y <= nyz; try_y++) {
      // try_y must divide the ranks left after the x split. Testing it against
      // nproc let e.g. nproc = 6, try_x = 2, try_y = 2 through, which gives
      // try_z = 3 / 2 = 1 and a 2x2x1 grid covering only 4 of the 6 ranks.
      if (nyz % try_y != 0)
        continue;
      int try_z = nyz / try_y;
      real areaxy = glen.x * glen.y / (try_x * try_y);
      real areayz = glen.y * glen.z / (try_y * try_z);
      real areaxz = glen.z * glen.x / (try_z * try_x);
      real area = areaxy + areayz + areaxz;
      if (area < best_area) {
        best_area = area;
        nx = try_x;
        ny = try_y;
        nz = try_z;
      }
    }
  }
  assert(nx > 0);
  vecset3(mpp->dim, nx, ny, nz);
  // Rank layout: pid = (px * ny + py) * nz + pz, i.e. z varies fastest.
  int px = pid / (ny * nz);
  int py = (pid / nz) % ny;
  int pz = pid % nz;
  vecset3(mpp->loc, px, py, pz);

  // Local box: presumably lo = gbox->lo + llen * loc and
  // hi = gbox->lo + llen * (loc + 1) via vecmuladdv (TODO confirm helper order).
  vec<real> llen;
  vecdivv(llen, glen, mpp->dim);
  veccpy(mpp->llen, llen);
  vecmuladdv(mpp->lbox.lo, llen, mpp->loc, gbox->lo);
  vec<int> locp1;
  vecadd(locp1, mpp->loc, 1);
  vecmuladdv(mpp->lbox.hi, llen, locp1, gbox->lo);
  // Snap the last rank on each axis to the exact global upper bound so that
  // floating-point round-off cannot leave a gap at the box edge.
  if (px == nx - 1)
    mpp->lbox.hi.x = gbox->hi.x;
  if (py == ny - 1)
    mpp->lbox.hi.y = gbox->hi.y;
  if (pz == nz - 1)
    mpp->lbox.hi.z = gbox->hi.z;

  // Periodic neighbours, encoded with the same (x, y, z) -> rank layout.
  int xnext = (px + 1) % nx;
  int ynext = (py + 1) % ny;
  int znext = (pz + 1) % nz;
  mpp->next.x = (xnext * ny + py) * nz + pz;
  mpp->next.y = (px * ny + ynext) * nz + pz;
  mpp->next.z = (px * ny + py) * nz + znext;

  int xprev = (px - 1 + nx) % nx;
  int yprev = (py - 1 + ny) % ny;
  int zprev = (pz - 1 + nz) % nz;
  mpp->prev.x = (xprev * ny + py) * nz + pz;
  mpp->prev.y = (px * ny + yprev) * nz + pz;
  mpp->prev.z = (px * ny + py) * nz + zprev;
}

/*
 * Allocate the four communication byte buffers (send/recv x prev/next).
 *
 * Every buffer gets room for the largest face of the cell grid, counted in
 * cells from grid->nall (xy, xz and zy faces), times sizeof(celldata_t).
 * That byte size is also recorded in mpp->max_comm_size.
 */
void comm_init_buf(mpp_t *mpp, cellgrid_t *grid) {
  auto face_xy = grid->nall.x * grid->nall.y;
  auto face_xz = grid->nall.x * grid->nall.z;
  auto face_zy = grid->nall.z * grid->nall.y;

  int widest_face = 0;
  if (face_xy > widest_face)
    widest_face = face_xy;
  if (face_xz > widest_face)
    widest_face = face_xz;
  if (face_zy > widest_face)
    widest_face = face_zy;

  size_t nbytes = widest_face * sizeof(celldata_t);
  mpp->send_prev = esmd::allocate<char>(nbytes, "comm/send buf/prev");
  mpp->send_next = esmd::allocate<char>(nbytes, "comm/send buf/next");
  mpp->recv_prev = esmd::allocate<char>(nbytes, "comm/recv buf/prev");
  mpp->recv_next = esmd::allocate<char>(nbytes, "comm/recv buf/next");

  mpp->max_comm_size = nbytes;
}
#define CONFIG_SWIB
/*
 * Reduce `count` elements of `buf` with `op` across all ranks of mpp->comm,
 * delivering the result to rank 0.
 *
 * Rank 0 reduces in place (MPI_IN_PLACE), so its `buf` holds the result
 * afterwards. On every other rank `buf` is only the send buffer, and no
 * receive buffer is supplied.
 */
void comm_reduce(void *buf, int count, MPI_Datatype type, MPI_Op op, mpp_t *mpp) {
  const bool is_root = (mpp->pid == 0);
  void *sendbuf = is_root ? MPI_IN_PLACE : buf;
  void *recvbuf = is_root ? buf : NULL;
  MPI_Reduce(sendbuf, recvbuf, count, type, op, 0, mpp->comm);
}
void comm_allreduce(void *buf, int count, MPI_Datatype type, MPI_Op op, mpp_t *mpp) {
  MPI_Allreduce(MPI_IN_PLACE, buf, count, type, op, mpp->comm);
}
#include <stdarg.h>
// void comm_vreduce(int count, MPI_Datatype type, MPI_Op op, mpp_t *mpp, ...){
//   va_list vl;
//   va_start(vl,mpp);
//   if (type == MPI_INT || type == MPI_FLOAT) {
//     int buf[count];
//     int *ptr[count];
//     for (int i = 0; i < count; i ++) {
//       buf[i] = *(int*)va_arg(vl,int*);

//     }
//     comm_reduce(buf, count, type, op, mpp);
//     if (mpp->pid == 0) {

//     }
//   } else if (type == MPI_DOUBLE || type == MPI_LONG){
//     long buf[count];
//     for (int i = 0; i < count; i ++) {
//       buf[i] = *(long*)va_arg(vl,long*);
//     }
//     comm_reduce(buf, count, type, op, mpp);
//   } else {
//     int size;
//     MPI_Type_size(type, &size);
//     char buf[count * size];
//     for (int i = 0; i < count; i ++) {
//       memcpy(buf + i * size, va_arg(vl,void*), size);
//     }
//     comm_reduce(buf, count, type, op, mpp);
//   }
// }
void comm_vreduce(int count, MPI_Datatype type, MPI_Op op, mpp_t *mpp, ...) {
  va_list vl;
  va_start(vl, mpp);
  int typesize;
  MPI_Type_size(type, &typesize);
  if (typesize == 4) {
    int buf[count];
    int *ptr[count];
    for (int i = 0; i < count; i++) {
      ptr[i] = va_arg(vl, int*);
      buf[i] = *ptr[i];
    }
    comm_allreduce(buf, count, type, op, mpp);
    if (mpp->pid == 0) {
      for (int i = 0; i < count; i++) {
        *ptr[i] = buf[i];
      }
    }
  } else if (typesize == 8) {
    long buf[count];
    long *ptr[count];
    for (int i = 0; i < count; i++) {
      ptr[i] = va_arg(vl, long *);
      buf[i] = *ptr[i];
    }
    comm_allreduce(buf, count, type, op, mpp);
    if (mpp->pid == 0) {
      for (int i = 0; i < count; i++) {
        *ptr[i] = buf[i];
      }
    }
  } else {
    char buf[count * typesize];
    void *ptr[count];
    for (int i = 0; i < count; i++) {
      ptr[i] = va_arg(vl, void *);
      memcpy(buf + i * typesize, ptr[i], typesize);
    }
    comm_allreduce(buf, count, type, op, mpp);
    if (mpp->pid == 0) {
      for (int i = 0; i < count; i++) {
        memcpy(ptr[i], buf + i * typesize, typesize);
      }
    }
  }
}
void comm_vallreduce(int count, MPI_Datatype type, MPI_Op op, mpp_t *mpp, ...) {
  va_list vl;
  va_start(vl, mpp);
  int typesize;
  MPI_Type_size(type, &typesize);
  if (typesize == 4) {
    int buf[count];
    int *ptr[count];
    for (int i = 0; i < count; i++) {
      ptr[i] = va_arg(vl, int *);
      buf[i] = *ptr[i];
    }
    comm_allreduce(buf, count, type, op, mpp);
    for (int i = 0; i < count; i++) {
      *ptr[i] = buf[i];
    }
  } else if (typesize == 8) {
    long buf[count];
    long *ptr[count];
    for (int i = 0; i < count; i++) {
      ptr[i] = va_arg(vl, long *);
      buf[i] = *ptr[i];
    }
    comm_allreduce(buf, count, type, op, mpp);
    for (int i = 0; i < count; i++) {
      *ptr[i] = buf[i];
    }
  } else {
    char buf[count * typesize];
    void *ptr[count];
    for (int i = 0; i < count; i++) {
      ptr[i] = va_arg(vl, void *);
      memcpy(buf + i * typesize, ptr[i], typesize);
    }
    comm_allreduce(buf, count, type, op, mpp);
    for (int i = 0; i < count; i++) {
      memcpy(ptr[i], buf + i * typesize, typesize);
    }
  }
}

/*
 * Sum the statistics record across ranks into rank 0's copy.
 * Treats *stat as 13 consecutive doubles. Presumably these are all the
 * accumulated fields of mdstat_t (energy terms such as ecoul, evdwl, ...);
 * TODO confirm against its definition.
 */
void comm_reduce_stat(mdstat_t *stat, mpp_t *mpp) {
  const int nvalues = 13;
  comm_reduce(stat, nvalues, MPI_DOUBLE, MPI_SUM, mpp);
}
/*
 * Sum the statistics record across ranks; every rank receives the totals.
 * Treats *stat as 13 consecutive doubles. Presumably these are all the
 * accumulated fields of mdstat_t; TODO confirm against its definition.
 */
void comm_allreduce_stat(mdstat_t *stat, mpp_t *mpp) {
  const int nvalues = 13;
  // printf("%f %f %f %f %f %f\n", stat->ecoul, stat->evdwl, stat->ebond, stat->eangle, stat->etori, stat->eimpr);
  comm_allreduce(stat, nvalues, MPI_DOUBLE, MPI_SUM, mpp);
  // printf("%f %f %f %f %f %f\n", stat->ecoul, stat->evdwl, stat->ebond, stat->eangle, stat->etori, stat->eimpr);
}
#include "buffer.hpp"
#define HAS & 1L <<
/*
 * Append the fields of `cell` selected by the compile-time bit mask
 * PACK_MASK to `buf`. Bit CF_<FIELD> enables a field. Enabled fields are
 * appended in the fixed order natom, x, v, f, q, tag, t, rigid. Each per-atom
 * array is passed with cell->natom as its length (append(ptr, n) is assumed to
 * copy n elements).
 */
template<int PACK_MASK>
void pack_cell(pack_buffer *buf, celldata_t *cell){
  // PACK_MASK is a template constant, so each test can fold away at compile time.
  if ((PACK_MASK & (1L << CF_N)) != 0) buf->append(cell->natom);
  if ((PACK_MASK & (1L << CF_X)) != 0) buf->append(cell->x, cell->natom);
  if ((PACK_MASK & (1L << CF_V)) != 0) buf->append(cell->v, cell->natom);
  if ((PACK_MASK & (1L << CF_F)) != 0) buf->append(cell->f, cell->natom);
  if ((PACK_MASK & (1L << CF_Q)) != 0) buf->append(cell->q, cell->natom);
  if ((PACK_MASK & (1L << CF_TAG)) != 0) buf->append(cell->tag, cell->natom);
  if ((PACK_MASK & (1L << CF_TYPE)) != 0) buf->append(cell->t, cell->natom);
  if ((PACK_MASK & (1L << CF_RIGID)) != 0) buf->append(cell->rigid, cell->natom);
}


// #include "listed.hpp"
// template void forward_comm(mpp_t *mpp, cellgrid_t *grid, listed_force_cell_ents<harmonic_bond_param, tagint, MAX_BONDED_CELL> *cells, listed_force_cell_ents<harmonic_bond_param, tagint, MAX_BONDED_CELL>::pack_export&, listed_force_cell_ents<harmonic_bond_param, tagint, MAX_BONDED_CELL>::unpack_export&);
#include "comm.gen.h"
